{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "import pytorch_mask_rcnn as pmr\n",
    "\n",
    "\n",
    "# Settings a reader may want to change.\n",
    "use_cuda = True\n",
    "dataset = \"voc\"\n",
    "ckpt_path = \"../ckpt/maskrcnn_voc-5.pth\"\n",
    "data_dir = \"E:/PyTorch/data/voc2012/\"\n",
    "\n",
    "# Run on the GPU only when one is available and use_cuda allows it.\n",
    "device_name = \"cuda\" if torch.cuda.is_available() and use_cuda else \"cpu\"\n",
    "device = torch.device(device_name)\n",
    "if device.type == \"cuda\":\n",
    "    pmr.get_gpu_prop(show=True)\n",
    "print(\"\\ndevice: {}\".format(device))\n",
    "\n",
    "# Wrap the \"val\" split in a Subset built from a random permutation,\n",
    "# so iterating over d visits the samples in shuffled order.\n",
    "ds = pmr.datasets(dataset, data_dir, \"val\", train=True)\n",
    "indices = torch.randperm(len(ds)).tolist()\n",
    "d = torch.utils.data.Subset(ds, indices)\n",
    "\n",
    "# One output class per dataset class plus one extra\n",
    "# (presumably the background class - confirm against pytorch_mask_rcnn).\n",
    "num_classes = len(ds.classes) + 1\n",
    "model = pmr.maskrcnn_resnet50(True, num_classes).to(device)\n",
    "model.eval()\n",
    "# NOTE(review): looks like the minimum score for a detection to be kept - confirm.\n",
    "model.head.score_thresh = 0.3\n",
    "\n",
    "# An empty ckpt_path skips loading and keeps the weights the constructor produced.\n",
    "if ckpt_path:\n",
    "    checkpoint = torch.load(ckpt_path, map_location=device)\n",
    "    model.load_state_dict(checkpoint[\"model\"])\n",
    "    print(checkpoint[\"eval_info\"])\n",
    "    # Drop the loaded copy and release cached CUDA memory.\n",
    "    del checkpoint\n",
    "    torch.cuda.empty_cache()\n",
    "\n",
    "# Turn off gradient tracking for every parameter.\n",
    "for param in model.parameters():\n",
    "    param.requires_grad_(False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of samples from d to run through the model and display.\n",
    "iters = 3\n",
    "\n",
    "# Bound the loop up front: iters <= 0 now displays nothing (the previous\n",
    "# enumerate/break form always processed at least one sample), and no sample\n",
    "# beyond the requested count is ever loaded.\n",
    "for i in range(min(iters, len(d))):\n",
    "    # target is not used for inference or display, so it is not copied to the device.\n",
    "    image, target = d[i]\n",
    "    image = image.to(device)\n",
    "\n",
    "    # Inference only: no autograd graph is needed.\n",
    "    with torch.no_grad():\n",
    "        result = model(image)\n",
    "\n",
    "    plt.figure(figsize=(12, 15))\n",
    "    pmr.show(image, result, ds.classes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:pytorch]",
   "language": "python",
   "name": "conda-env-pytorch-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
